{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Methods for Multiscale Comparative Connectomics\n",
    "\n",
    "This demo shows you how to use methods in `graspologic` to analyze patterns\n",
    "in brain connectivity in connectomics datasets. We specifically demonstrate\n",
    "methods for identifying differences in edges and vertices across subjects. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import warnings\n",
    "pd.options.mode.chained_assignment = None\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Duke mouse brain dataset\n",
    "\n",
    "Dataset of 32 mouse connectomes derived from whole-brain diffusion\n",
    "magnetic resonance imaging of four distinct mouse genotypes:\n",
    "BTBR T+ Itpr3tf/J (BTBR), C57BL/6J(B6), CAST/EiJ (CAST), and DBA/2J (DBA2).\n",
    "For each strain, connectomes were generated from eight age-matched mice\n",
    "(N = 8 per strain), with a sex distribution of four males and four females.\n",
    "Each connectome was parcellated using asymmetric Waxholm Space, yielding a\n",
    "vertex set with a total of 332 regions of interest (ROIs) symmetrically\n",
    "distributed across the left and right hemispheres. Within a given\n",
    "hemisphere, there are seven superstructures consisting up multiple ROIs,\n",
    "resulting in a total of 14 distinct communities in each connectome."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graspologic.datasets import load_mice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the full mouse dataset\n",
    "mice = load_mice()\n",
    "\n",
    "# Stack all adjacency matrices in a 3D numpy array\n",
    "graphs = np.array(mice.graphs)\n",
    "\n",
    "# Get sample parameters\n",
    "n_subjects = mice.meta[\"n_subjects\"]\n",
    "n_vertices = mice.meta[\"n_vertices\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the set of graphs by genotype\n",
    "btbr = graphs[mice.labels == \"BTBR\"]\n",
    "b6 = graphs[mice.labels == \"B6\"]\n",
    "cast = graphs[mice.labels == \"CAST\"]\n",
    "dba2 = graphs[mice.labels == \"DBA2\"]\n",
    "\n",
    "connectomes = [btbr, b6, cast, dba2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Identifying Signal Edges\n",
    "\n",
    "The simplest approach for comparing connectomes is to treat them as a _bag of edges_ without considering interactions between the edges.\n",
    "Serially performing univariate statistical analyses at each edge enables the discovery of _signal edges_ whose neurological connectivity differs across categorical or dimensional phenotypes.\n",
    "Here, we demonstrate the possibility of using Distance Correlation (`Dcorr`), a nonparametric universally consistent test, to detect changes in edges.\n",
    "\n",
    "In this model, we assume that each edge in the connectome is independently and\n",
    "identically sampled from some distribution $F_i$, where $i$ represents the group\n",
    "to which the given connectome belongs. In this setting the groups are the mouse\n",
    "genotypes. `Dcorr` allows us to test the following hypothesis:\n",
    "\n",
    "\\begin{align*}\n",
    "H_0:\\ &F_1 = F_2 = \\ldots F_k \\\\\n",
    "H_A:\\ &\\exists \\ j \\neq j' \\text{ s.t. } F_j \\neq F_{j'}\n",
    "\\end{align*}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hyppo.ksample import KSample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the connectomes in this dataset are undirected, we only need to do edge\n",
    "comparisons on the upper triangle of the adjacency matrices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make iterator for traversing the upper triangle of the connectome\n",
    "indices = zip(*np.triu_indices(n_vertices, 1))"
   ]
  },
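  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The upper-triangle traversal can be illustrated on a toy graph.\n",
    "This is a minimal sketch, not part of the original analysis: `toy_graph` and the\n",
    "other `toy_*` names are hypothetical, and the 4-vertex matrix is made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy symmetric adjacency matrix with 4 vertices (hypothetical values)\n",
    "toy_graph = np.array([\n",
    "    [0, 1, 2, 3],\n",
    "    [1, 0, 4, 5],\n",
    "    [2, 4, 0, 6],\n",
    "    [3, 5, 6, 0],\n",
    "])\n",
    "toy_n = toy_graph.shape[0]\n",
    "\n",
    "# k=1 skips the main diagonal, so self-loops are excluded\n",
    "toy_rows, toy_cols = np.triu_indices(toy_n, k=1)\n",
    "toy_pairs = list(zip(toy_rows, toy_cols))\n",
    "\n",
    "# An undirected graph on n vertices has n * (n - 1) / 2 distinct edges\n",
    "assert len(toy_pairs) == toy_n * (toy_n - 1) // 2\n",
    "\n",
    "# Each undirected edge weight is visited exactly once\n",
    "print(toy_graph[toy_rows, toy_cols])"
   ]
  },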
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "edge_pvals = []\n",
    "\n",
    "for roi_i, roi_j in indices:\n",
    "    \n",
    "    # Get the (i,j)-th edge for each connectome\n",
    "    samples = [genotype[:, roi_i, roi_j] for genotype in connectomes]\n",
    "    \n",
    "    # Calculate the p-value for the (i,j)-th edge\n",
    "    try:\n",
    "        statistic, pvalue = KSample(\"Dcorr\").test(*samples, reps=10000, workers=-1)\n",
    "    except ValueError:\n",
    "        # A ValueError is thrown when any of the samples have equal edge \n",
    "        # weights (i.e. one of the inputs has 0 variance)\n",
    "        statistic = np.nan\n",
    "        pvalue = 1\n",
    "\n",
    "    edge_pvals.append([roi_i, roi_j, statistic, pvalue])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Connectomes are a high-dimensional dataype.\n",
    "Thus, statistical tests on components of the connectome (e.g. edge and vertices)\n",
    "results in multiple comparisons.\n",
    "We recommend correcting for multiple comparisons using the Holm-Bonferroni (HB)\n",
    "corection.\n",
    "\n",
    "The algorithm is described below:\n",
    "1. Sort the p-values from lowest-to-highest $P_1, P_2, \\dots, P_n$, where $n$ is the number of tests\n",
    "2. Correct the p-value as $P_1(n), P_2(n-1), \\dots, P_n(1)$\n",
    "3. If any corrected p-value is $>1$, replace with $1$\n",
    "4. If the corrected p-value is less than a significance level $\\alpha$, reject"
   ]
  },
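  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The correction can be sketched on a handful of p-values.\n",
    "This is a minimal illustration under stated assumptions, not the notebook's own implementation:\n",
    "`holm_bonferroni_sketch` and `toy_pvals` are hypothetical names, and the input values are made up.\n",
    "The sketch also takes a running maximum of the corrected values, an extra step that the standard\n",
    "step-down formulation of Holm-Bonferroni uses so that a hypothesis is never rejected after a\n",
    "smaller p-value has failed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def holm_bonferroni_sketch(pvals):\n",
    "    # Return Holm-Bonferroni corrected p-values in the original input order\n",
    "    pvals = np.asarray(pvals, dtype=float)\n",
    "    n = pvals.size\n",
    "    order = np.argsort(pvals)             # sort from lowest to highest\n",
    "    factors = n - np.arange(n)            # multipliers n, n-1, ..., 1\n",
    "    corrected = pvals[order] * factors\n",
    "    corrected = np.minimum(corrected, 1)  # cap corrected values at 1\n",
    "    corrected = np.maximum.accumulate(corrected)  # running maximum (monotone)\n",
    "    out = np.empty(n)\n",
    "    out[order] = corrected                # undo the sort\n",
    "    return out\n",
    "\n",
    "\n",
    "toy_pvals = [0.01, 0.04, 0.03, 0.005]\n",
    "toy_holm = holm_bonferroni_sketch(toy_pvals)\n",
    "\n",
    "# Only the two smallest raw p-values stay below 0.05 after correction\n",
    "print(toy_holm)\n",
    "print(toy_holm < 0.05)"
   ]
  },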
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the nested list to a dataframe\n",
    "edge_pvals = pd.DataFrame(edge_pvals, columns=[\"ROI_1\", \"ROI_2\", \"stat\", \"pval\"])\n",
    "\n",
    "# Correct p-values using the Holm-Bonferroni correction\n",
    "edge_pvals.sort_values(by=\"pval\", inplace=True, ignore_index=True)\n",
    "pval_rank = edge_pvals[\"pval\"].rank(ascending=False, method=\"max\")\n",
    "edge_pvals[\"holm_pval\"] = edge_pvals[\"pval\"].multiply(pval_rank)\n",
    "edge_pvals[\"holm_pval\"] = edge_pvals[\"holm_pval\"].apply(\n",
    "    lambda pval: 1 if pval > 1 else pval\n",
    ")\n",
    "\n",
    "# Test for significance using alpha=0.05\n",
    "alpha = 0.05\n",
    "edge_pvals[\"significant\"] = (edge_pvals[\"holm_pval\"] < alpha)\n",
    "\n",
    "# Get the top 10 strongest signal edges\n",
    "edge_pvals_top = edge_pvals.head(10)\n",
    "\n",
    "# Replace ROI indices with actual names\n",
    "def lookup_roi_name(index):\n",
    "    hemisphere = \"R\" if index // 166 else \"L\"\n",
    "    index = index % 166\n",
    "    roi_name = mice.atlas.query(f\"ROI == {index+1}\")[\"Structure\"].item()\n",
    "    return f\"{roi_name} ({hemisphere})\"\n",
    "\n",
    "edge_pvals_top[\"ROI_1\"] = edge_pvals_top[\"ROI_1\"].apply(lookup_roi_name)\n",
    "edge_pvals_top[\"ROI_2\"] = edge_pvals_top[\"ROI_2\"].apply(lookup_roi_name)\n",
    "\n",
    "edge_pvals_top.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that none of the edges achieve significance at $\\alpha=0.05$ following \n",
    "Holm-Bonferroni correction.\n",
    "We might expect this, given that we are correcting for $N=54,946$ tests.\n",
    "Instead of the magnitude, the **ranking of the p-values** can be used to determine signal edges."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Identifying Signal Vertices\n",
    "\n",
    "A sample of connectomes can be jointly embedded in a low-dimensional Euclidean space using the omnibus embedding (`omni`).\n",
    "A host of machine learning tasks can be accomplished with this joint embedded representation of the connectome, such as clustering or classification of vertices.\n",
    "Here, we use the embedding to formulate a statistical test that can be used to identify vertices that are strongly associated with given phenotypes.\n",
    "According to a Central Limit Theorem for `omni`, these latent positions are universally consistent and asymptotically normal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graspologic.embed import OmnibusEmbed\n",
    "from graspologic.plot import pairplot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Jointly embed graphs using omnibus embedding\n",
    "embedder = OmnibusEmbed()\n",
    "omni_embedding = embedder.fit_transform(graphs)\n",
    "print(omni_embedding.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For each of the 32 mice, `omni` embeds each vertex in the connectome to a\n",
    "latent position vector $x_i \\in \\mathbb{R}^5$.\n",
    "We test for differences in the distribution of vertex latent positions using \n",
    "the `MGC` independence test implemented in `hyppo`.\n",
    "This test essentially acts as a nonparametric MANOVA test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with warnings.catch_warnings():\n",
    "    warnings.filterwarnings(\"ignore\", module=\"hyppo\")\n",
    "    warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n",
    "    vertex_pvals = []\n",
    "\n",
    "    for vertex_i in range(n_vertices):\n",
    "    \n",
    "    # Get the embedding of the i-th vertex for \n",
    "        samples = [\n",
    "            omni_embedding[mice.labels==genotype, vertex_i, :] \n",
    "            for genotype in np.unique(mice.labels)\n",
    "        ]\n",
    "    \n",
    "    # Calculate the p-value for the i-th vertex\n",
    "        statistic, pvalue, _ = KSample(\"MGC\").test(\n",
    "            *samples, reps=250, workers=-1\n",
    "        )\n",
    "        vertex_pvals.append([vertex_i, statistic, pvalue])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Important:** When running on data, make sure `reps` $\\geq \\frac{|V|}{\\alpha}$ where $|V|$ is the cardinality of the vertex set and $\\alpha$ is the significance level for your test.\n",
    "Again, this has to do with p-value correction:\n",
    "because `MGC` is a permutation test, you need many replications to ensure accurate computation of the p-value.\n",
    "Because this notebook is primarily for demonstration purposes, we use a small value of `reps=250`. For this dataset, a more appropriate number of permutations would be around $10^6$."
   ]
  },
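  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bound on `reps` follows from the smallest p-value a permutation test can report.\n",
    "The sketch below assumes the common estimator $(1 + \\text{count}) / (1 + \\text{reps})$, so the smallest\n",
    "attainable raw p-value is $1 / (1 + \\text{reps})$; whether a given library uses exactly this estimator\n",
    "is an assumption here.\n",
    "`smallest_holm_pvalue_sketch` is a hypothetical helper, and the test counts come from the 332 ROIs of this dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def smallest_holm_pvalue_sketch(n_tests, n_reps):\n",
    "    # Smallest raw p-value a permutation test can report, assuming the\n",
    "    # (1 + count) / (1 + reps) estimator with count = 0\n",
    "    min_raw = 1 / (n_reps + 1)\n",
    "    # Holm-Bonferroni multiplies the smallest p-value by the number of tests\n",
    "    return min(1.0, n_tests * min_raw)\n",
    "\n",
    "\n",
    "n_vertex_tests = 332            # one test per ROI\n",
    "n_edge_tests = 332 * 331 // 2   # one test per upper-triangle edge\n",
    "\n",
    "# Best achievable corrected p-value for a few choices of reps\n",
    "for n_reps in (250, 6_640, 1_100_000):\n",
    "    print(\n",
    "        n_reps,\n",
    "        smallest_holm_pvalue_sketch(n_vertex_tests, n_reps),\n",
    "        smallest_holm_pvalue_sketch(n_edge_tests, n_reps),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under this assumption, `reps=250` cannot yield a corrected p-value below $0.05$ for any vertex,\n",
    "while the same bound applied to the $54{,}946$ edge tests is on the order of $10^6$ replications."
   ]
  },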
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the nested list to a dataframe\n",
    "vertex_pvals = pd.DataFrame(vertex_pvals, columns=[\"ROI\", \"stat\", \"pval\"])\n",
    "\n",
    "# Correct p-values using the Holm-Bonferroni correction\n",
    "vertex_pvals.sort_values(by=\"pval\", inplace=True, ignore_index=True)\n",
    "pval_rank = vertex_pvals[\"pval\"].rank(ascending=False, method=\"max\")\n",
    "vertex_pvals[\"holm_pval\"] = vertex_pvals[\"pval\"].multiply(pval_rank)\n",
    "vertex_pvals[\"holm_pval\"] = vertex_pvals[\"holm_pval\"].apply(\n",
    "    lambda pval: 1 if pval > 1 else pval\n",
    ")\n",
    "\n",
    "# Test for significance using alpha=0.05\n",
    "alpha = 0.05\n",
    "vertex_pvals[\"significant\"] = (vertex_pvals[\"holm_pval\"] < alpha)\n",
    "\n",
    "# Get the top 10 strongest signal edges\n",
    "vertex_pvals_top = vertex_pvals.head(10)\n",
    "\n",
    "# Replace ROI indices with actual names\n",
    "vertex_pvals_top[\"ROI\"] = vertex_pvals_top[\"ROI\"].apply(lookup_roi_name)\n",
    "\n",
    "vertex_pvals_top.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use pairplots to visualize the embeddings of specific vertices.\n",
    "Below are pairsplots of the corpus callosum from the left and right hemispheres."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "left = pairplot(\n",
    "    omni_embedding[:, 121 - 1, :3], \n",
    "    mice.labels, \n",
    "    palette=[\"#e7298a\", \"#1b9e77\", \"#d95f02\", \"#7570b3\"],\n",
    "    title=\"Left Corpus Callosum\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "right = pairplot(\n",
    "    omni_embedding[:, 121 + 166 - 1, :3], \n",
    "    mice.labels, \n",
    "    palette=[\"#e7298a\", \"#1b9e77\", \"#d95f02\", \"#7570b3\"],\n",
    "    title=\"Right Corpus Callosum\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Whole-brain Comparisons\n",
    "\n",
    "We can use the results of the omnibus embedding to perform whole-brain comparisons\n",
    "across subjects from different phenotypes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graspologic.embed import ClassicalMDS\n",
    "from graspologic.plot import heatmap"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two-dimensional representations of each connectome were obtained by using \n",
    "Classical Multidimensional Scaling (`cmds`) to reduce the dimensionality \n",
    "of the embeddings obtained by `omni`."
   ]
  },
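  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough picture of what this step computes, the sketch below runs classical MDS by hand with NumPy on random toy data.\n",
    "It is not `graspologic`'s implementation: `classical_mds_sketch` and the `toy_*` arrays are hypothetical,\n",
    "and it assumes the dissimilarity between two subjects is the Frobenius norm of the difference of their embedding matrices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def classical_mds_sketch(dissimilarity, n_components=2):\n",
    "    # Double-center the squared dissimilarities to get a Gram matrix\n",
    "    n = dissimilarity.shape[0]\n",
    "    centering = np.eye(n) - np.ones((n, n)) / n\n",
    "    gram = -0.5 * centering @ (dissimilarity ** 2) @ centering\n",
    "    # Keep the eigenvectors with the largest eigenvalues\n",
    "    eigvals, eigvecs = np.linalg.eigh(gram)\n",
    "    top = np.argsort(eigvals)[::-1][:n_components]\n",
    "    # Scale each eigenvector by the square root of its (non-negative) eigenvalue\n",
    "    return eigvecs[:, top] * np.sqrt(np.clip(eigvals[top], 0, None))\n",
    "\n",
    "\n",
    "# Toy stand-in for an omnibus embedding: 6 subjects, 10 vertices, 3 dimensions\n",
    "toy_rng = np.random.default_rng(0)\n",
    "toy_embeddings = toy_rng.normal(size=(6, 10, 3))\n",
    "\n",
    "# Pairwise Frobenius distances between the subjects' embedding matrices\n",
    "toy_flat = toy_embeddings.reshape(len(toy_embeddings), -1)\n",
    "toy_dis = np.linalg.norm(toy_flat[:, None, :] - toy_flat[None, :, :], axis=-1)\n",
    "\n",
    "# One 2D point per subject\n",
    "toy_coords = classical_mds_sketch(toy_dis, n_components=2)\n",
    "print(toy_coords.shape)"
   ]
  },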
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Further reduce embedding dimensionality using cMDS\n",
    "cmds = ClassicalMDS(2)\n",
    "cmds_embedding = cmds.fit_transform(omni_embedding)\n",
    "cmds_embedding = pd.DataFrame(cmds_embedding, columns=[\"Dimension 1\", \"Dimension 2\"])\n",
    "cmds_embedding[\"Genotype\"] = mice.labels\n",
    "\n",
    "# Embedding with BTBR\n",
    "sns.scatterplot(\n",
    "    x=\"Dimension 1\",\n",
    "    y=\"Dimension 2\",\n",
    "    hue=\"Genotype\",\n",
    "    data=cmds_embedding,\n",
    "    palette=[\"#e7298a\", \"#1b9e77\", \"#d95f02\", \"#7570b3\"],\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the dissimilarity between subjects' connectomes using the cMDS embedding\n",
    "dis = cmds.dissimilarity_matrix_\n",
    "scaled_dissimilarity = dis / np.max(dis)\n",
    "\n",
    "heatmap(scaled_dissimilarity,\n",
    "        context=\"paper\",\n",
    "        inner_hier_labels=mice.labels,)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that mice from the same genotype are most similar to each other, as expected."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}